{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "79651c08",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install open3d plotly\n",
    "\n",
    "from pathlib import Path\n",
    "from tqdm import tqdm\n",
    "import subprocess\n",
    "import numpy as np\n",
    "import open3d\n",
    "from scipy.spatial.distance import cdist\n",
    "import plotly.graph_objects as go\n",
    "\n",
    "import pycolmap\n",
    "from pixsfm.eval.eth3d.utils import Paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "8abcae96",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_path = Path(\"../datasets/ETH3D\")\n",
    "outputs_path = Path(\"../outputs/ETH3D\")\n",
    "\n",
    "tag_raw = \"pixsfm-interp-area\"\n",
    "tag_refined = \"pixsfm-interp-pil-linear\"\n",
    "keypoints = \"sift\"\n",
    "scene = \"terrace\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9fb37cf",
   "metadata": {},
   "source": [
    "# Triangulation error vs track length "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2fdeafb",
   "metadata": {},
   "source": [
    "Build the ETH3D pipeline tool to merge Lidar scans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "771281e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "cd ..\n",
    "git clone https://github.com/ETH3D/dataset-pipeline/\n",
    "cd dataset-pipeline/\n",
    "git clone https://github.com/laurentkneip/opengv\n",
    "cd opengv\n",
    "rm -rf build && mkdir build && cd build\n",
    "cmake .. && make -j\n",
    "cd ..\n",
    "rm -rf build && mkdir build/ && cd build\n",
    "ls ../opengv/build\n",
    "cmake -DCMAKE_BUILD_TYPE=RelWithDebInfo -Dopengv_DIR=$(pwd)/../opengv/build ..\n",
    "make -j"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50d499a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_accuracy_pcd(tag, tol):\n",
    "    sfm_dir = Paths().interpolate(scene=scene, outputs=outputs_path, method=keypoints, tag=tag).triangulation\n",
    "    accuracy_path = sfm_dir / f\"reconstruction.tolerance_{tol}.ply\"\n",
    "    cmd = [\n",
    "        \"../multi-view-evaluation/build/ETH3DMultiViewEvaluation\",\n",
    "        \"--ground_truth_mlp_path\", str(scan_path),\n",
    "        \"--reconstruction_ply_path\", str(sfm_dir/\"reconstruction.ply\"),\n",
    "        \"--tolerances\", str(tol),\n",
    "        \"--accuracy_cloud_output_path\", str(sfm_dir/\"reconstruction\"),\n",
    "    ]\n",
    "    subprocess.run(cmd, check=True)\n",
    "    assert accuracy_path.exists()\n",
    "    return sfm_dir, accuracy_path\n",
    "\n",
    "def get_error_tl(sfm_path, accuracy_path):\n",
    "    # extract the accuracy labels from the evaluation outputs\n",
    "    pcd_acc = open3d.io.read_point_cloud(str(accuracy_path))\n",
    "    colors = np.asarray(pcd_acc.colors)\n",
    "    good = np.array([0, 1, 0])\n",
    "    bad = np.array([1, 0, 0])\n",
    "    ignore = np.array([0, 0, 1])\n",
    "    is_ignore = np.all(colors == ignore, -1)\n",
    "    is_good = np.all(colors == good, -1)\n",
    "    print(is_good[~is_ignore].mean())\n",
    "\n",
    "    # compute the SfM track lengths\n",
    "    rec = pycolmap.Reconstruction(sfm_path)\n",
    "    pids = np.array(sorted(rec.points3D))\n",
    "    p3ds = np.array([rec.points3D[i].xyz for i in pids])\n",
    "    all_tl = np.array([rec.points3D[i].track.length() for i in pids])\n",
    "\n",
    "    # propagate the accuracy labels from the evaluation pcd to the SfM IDs\n",
    "    dist_sfm_acc = cdist(p3ds, np.asarray(pcd_acc.points))\n",
    "    assert np.all(np.min(dist_sfm_acc, 1) < 1e-5)\n",
    "    nearest = np.argmin(dist_sfm_acc, 1)\n",
    "    pid2acc = {pids[idx]: i for idx, i in enumerate(nearest)}\n",
    "    valid = np.array([not is_ignore[pid2acc[i]] for i in pids])\n",
    "\n",
    "    # compute the 3D error w.r.t the GT Lidar pointcloud\n",
    "    p3ds_gt = []\n",
    "    for p3d in tqdm(p3ds):\n",
    "        _, idx_gt, _ = tree_eval.search_knn_vector_3d(p3d, 1)\n",
    "        p3d_gt = pcd_eval.points[idx_gt[0]]\n",
    "        p3ds_gt.append(p3d_gt)\n",
    "    p3ds_gt = np.array(p3ds_gt)\n",
    "    all_errs_3d = np.linalg.norm(p3ds - p3ds_gt, axis=-1)\n",
    "\n",
    "    return all_tl[valid], all_errs_3d[valid]\n",
    "\n",
    "scan_path = dataset_path / scene / \"dslr_scan_eval/scan_alignment.mlp\"\n",
    "scan_merged_path = scan_path.parent / \"merged.ply\"\n",
    "!../dataset-pipeline/build/NormalEstimator -i $scan_path -o $scan_merged_path\n",
    "pcd_eval = open3d.io.read_point_cloud(str(scan_merged_path))\n",
    "tree_eval = open3d.geometry.KDTreeFlann(pcd_eval)\n",
    "tol = 0.01\n",
    "\n",
    "sfm_raw, acc_raw = compute_accuracy_pcd(tag_raw, tol)\n",
    "tl_raw, errs_3d_raw = get_error_tl(sfm_raw, acc_raw)\n",
    "print(tag_raw, np.mean(errs_3d_raw < tol))\n",
    "\n",
    "sfm_ref, acc_ref = compute_accuracy_pcd(tag_refined, tol)\n",
    "tl_ref, errs_3d_ref = get_error_tl(sfm_ref, acc_ref)\n",
    "print(tag_refined, np.mean(errs_3d_ref < tol))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e1b25f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_tl = 8\n",
    "min_err = 2e-3\n",
    "fig = go.Figure()\n",
    "\n",
    "fig.add_shape(\n",
    "    type='line', x0=0., y0=min_err, x1=1., y1=min_err,\n",
    "    line=dict(color='Gray', dash='dot'),\n",
    "    yref='y', xref= 'paper', \n",
    ")\n",
    "\n",
    "fig.add_annotation(x=-.004, y=np.log10(2e-3),\n",
    "            text=\"Lidar\", xref=\"paper\",\n",
    "           xanchor=\"right\", yanchor='middle',\n",
    "            showarrow=False)\n",
    "\n",
    "fig.add_trace(go.Box(\n",
    "    y=np.maximum(errs_3d_raw, min_err),\n",
    "    x=np.minimum(tl_raw, max_tl),\n",
    "    name='raw',\n",
    "    boxpoints=False, jitter=0.3,\n",
    "))\n",
    "\n",
    "fig.add_trace(go.Box(\n",
    "    y=np.maximum(errs_3d_ref, min_err),\n",
    "    x=np.minimum(tl_ref, max_tl),\n",
    "    name='refined',\n",
    "    boxpoints=False, jitter=0.3,\n",
    "))\n",
    "\n",
    "bins = np.unique(np.minimum(tl_raw, max_tl))\n",
    "ax = dict(linewidth=1, linecolor='black', mirror=True, gridcolor='rgb(210, 210, 210)')\n",
    "fig.update_layout(\n",
    "    yaxis=dict(\n",
    "        type=\"log\",\n",
    "        tickvals=[1e-3, 1e-2, 1e-1, 1],\n",
    "        ticktext=['1mm', '1cm', '10cm', '1m'],\n",
    "        range=[-3.1, 0.9],\n",
    "        **ax,\n",
    "    ),\n",
    "    xaxis = dict(\n",
    "        tickmode = 'array', tickvals = bins,\n",
    "        ticktext = list(map(str, bins[:-1])) + [str(bins[-1])+'+'],\n",
    "        **ax,\n",
    "    )\n",
    ")\n",
    "\n",
    "fig.update_layout(\n",
    "    xaxis_title='track length',\n",
    "    boxmode='group',\n",
    "    paper_bgcolor='rgba(0,0,0,0)',\n",
    "    plot_bgcolor='rgba(0,0,0,0)',\n",
    "    legend=dict(\n",
    "        orientation=\"v\",\n",
    "        yanchor=\"top\",\n",
    "        y=0.98,\n",
    "        xanchor=\"right\",\n",
    "        x=0.93,\n",
    "    ),\n",
    "    font=dict(\n",
    "        size=18,\n",
    "        color=\"Black\"\n",
    "    ),\n",
    "    title=dict(\n",
    "        text='3D triangulation error',\n",
    "        y=0.99,\n",
    "        x=0.54,\n",
    "        xanchor='center',\n",
    "        yanchor='top',\n",
    "    ),\n",
    "    width=700,\n",
    "    height=600,\n",
    "    margin=go.layout.Margin(\n",
    "        l=0, #left margin\n",
    "        r=0, #right margin\n",
    "        b=0, #bottom margin\n",
    "        t=35, #top margin\n",
    "    ),\n",
    ")\n",
    "\n",
    "fig.show()\n",
    "# fig.write_image(\"plots/boxplot_error_vs_trancklength.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "afb52f3e",
   "metadata": {},
   "source": [
    "# Distribution of point displacements"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "id": "6fd17e48",
   "metadata": {},
   "outputs": [],
   "source": [
    "rec = pycolmap.Reconstruction(sfm_ref)\n",
    "rec_raw = pycolmap.Reconstruction(sfm_raw)\n",
    "\n",
    "def gather_projections(images, points3D):\n",
    "    im2xy = {}\n",
    "    for imid, image in rec.images.items():\n",
    "        cam = rec.cameras[image.camera_id]\n",
    "        for idx, p2d in enumerate(image.points2D):\n",
    "            if not p2d.has_point3D():\n",
    "                continue\n",
    "            p3d = rec.points3D[p2d.point3D_id]\n",
    "            xy, = image.project([p3d])\n",
    "            xy = cam.world_to_image(xy)\n",
    "            im2xy[(imid, idx)] = xy\n",
    "    return im2xy\n",
    "\n",
    "def gather_detections(images):\n",
    "    im2xy = {}\n",
    "    for imid, image in rec.images.items():\n",
    "        for idx, p2d in enumerate(image.points2D):\n",
    "            if not p2d.has_point3D():\n",
    "                continue\n",
    "            im2xy[(imid, idx)] = p2d.xy\n",
    "    return im2xy\n",
    "\n",
    "im2xy_ref = gather_projections(images, points3D)\n",
    "im2xy_raw = gather_detections(images_raw)\n",
    "# im2xy_raw = gather_projections(images_raw, points3D_raw)  # uncomment for movement vs geometric refinement\n",
    "# im2xy_raw = gather_detections(images)  # uncomment for movement of FBA only\n",
    "\n",
    "keys = list(im2xy_ref.keys() & im2xy_raw.keys())\n",
    "err = np.array([np.linalg.norm(im2xy_ref[k] - im2xy_raw[k]) for k in keys])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "933deb0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "v = np.sort(err)\n",
    "cum = np.arange(len(v)) / (len(v)-1) * 100\n",
    "# TODO: subsample based on density\n",
    "\n",
    "fig = go.Figure(data=go.Scatter(\n",
    "    x=v, y=cum,\n",
    "    fill='tozeroy',\n",
    "    fillcolor='rgba(99, 110, 250, .2)',\n",
    "))\n",
    "\n",
    "ax = dict(linewidth=1, linecolor='black', mirror=True, gridcolor='rgb(210, 210, 210)')\n",
    "fig.update_layout(\n",
    "    yaxis=dict(\n",
    "        range=[0, 100],\n",
    "        **ax,\n",
    "    ),\n",
    "    xaxis = dict(\n",
    "        type='log',\n",
    "        range=np.log10([0.2, 12]).tolist(),\n",
    "        **ax,\n",
    "    )\n",
    ")\n",
    "\n",
    "fig.update_layout(\n",
    "    xaxis_title='point movement [px]',\n",
    "    yaxis_title='cumulative distr. [%]',\n",
    "    paper_bgcolor='rgba(0,0,0,0)',\n",
    "    plot_bgcolor='rgba(0,0,0,0)',\n",
    "    font=dict(\n",
    "        size=18,\n",
    "        color=\"Black\"\n",
    "    ),\n",
    "    title=dict(\n",
    "        text='Distribution of 2D point movement',\n",
    "        y=0.99,\n",
    "        x=0.54,\n",
    "        xanchor='center',\n",
    "        yanchor='top',\n",
    "    ),\n",
    "    width=720,\n",
    "    height=400,\n",
    "    margin=go.layout.Margin(\n",
    "        l=0, #left margin\n",
    "        r=0, #right margin\n",
    "        b=0, #bottom margin\n",
    "        t=35, #top margin\n",
    "    ),\n",
    ")\n",
    "\n",
    "fig.show()\n",
    "# fig.write_image(\"plots/cumulative_movement.pdf\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
